{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0562090a",
   "metadata": {},
   "source": [
    "### Annotate variant carriers by gene (test code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "57c03457",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import pysam \n",
    "from io import StringIO\n",
    "from pybedtools import BedTool\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "559cb0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project working directory; every input/output path below is derived from it\n",
    "wkdir = \"/lustre/scratch126/humgen/projects/interval_rna/interval_rna_seq/thomasVDS/misexpression_v3\"\n",
    "wkdir_path = Path(wkdir)\n",
    "\n",
    "# Gene and chromosome under test\n",
    "gene_id = \"ENSG00000280095\"\n",
    "chrom = \"chr21\"\n",
    "\n",
    "# Per-chromosome phased genotype VCF\n",
    "vcf_path= f\"/lustre/scratch126/humgen/projects/interval_wgs/final_release_freeze/gt_phased/interval_wgs.{chrom}.gt_phased.vcf.gz\"\n",
    "# RNA sample IDs passing QC: headerless CSV, IDs are taken from the first column\n",
    "ran_ids_path = wkdir_path.joinpath(\"1_rna_seq_qc/aberrant_smpl_qc/smpls_pass_qc_4568.csv\")\n",
    "smpl_id_pass_qc_set = set(pd.read_csv(ran_ids_path, sep=\",\", header=None)[0])\n",
    "# RNA ID / EGAN ID linker table\n",
    "wgs_rna_paired_smpls_path = wkdir_path.joinpath(\"1_rna_seq_qc/wgs_rna_match/paired_wgs_rna_postqc_prioritise_wgs.tsv\")\n",
    "# Gene window / variant intersection table; built from `chrom` so it always matches vcf_path\n",
    "intersect_bed_path = wkdir_path.joinpath(f\"4_vrnt_enrich/snp_indel_count_carriers/intersect_beds/{chrom}_gene_vrnts_intersect.tsv\")\n",
    "# Root directory for the outputs of this test run\n",
    "root_dir = wkdir_path.joinpath(\"4_vrnt_enrich/snp_indel_count_carriers/tests\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "cbe1a9d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chr21 number of variants in ENSG00000280095 windows: 2343\n",
      "Number of samples in VCF with RNA ID and passing QC: 2821\n",
      "Number of RNA IDs passing QC with paired sample in VCF: 2821\n"
     ]
    }
   ],
   "source": [
    "# Write one row per (variant, carrier sample) for every variant in the gene windows of `gene_id`.\n",
    "# Output columns: vrnt_id, egan_id, genotype, AF, vrnt_type.\n",
    "root_dir_path = Path(root_dir)\n",
    "root_dir_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# load intersection bed file and keep the entries for the gene of interest\n",
    "intersect_bed_df = pd.read_csv(intersect_bed_path, sep=\"\\t\")\n",
    "gene_id_intersect_bed_df = intersect_bed_df[intersect_bed_df.gene_id == gene_id]\n",
    "# one row per variant inside the gene windows\n",
    "gene_id_intersect_no_dupl_df = gene_id_intersect_bed_df[[\"chrom_vrnt\", \"vrnt_id\", \"start_vrnt\", \"end_vrnt\"]].drop_duplicates()\n",
    "intersect_vrnt_ids_list = gene_id_intersect_no_dupl_df.vrnt_id.unique()\n",
    "print(f\"{chrom} number of variants in {gene_id} windows: {len(intersect_vrnt_ids_list)}\")\n",
    "# every variant ID must map to exactly one (chrom, start, end); fail before any output is written\n",
    "dupl_vrnt_ids = gene_id_intersect_no_dupl_df.vrnt_id[gene_id_intersect_no_dupl_df.vrnt_id.duplicated()].unique().tolist()\n",
    "if dupl_vrnt_ids:\n",
    "    raise ValueError(f\"Variant IDs with more than one set of coordinates: {dupl_vrnt_ids}\")\n",
    "\n",
    "# load RNA-EGAN ID linker and subset to samples passing QC\n",
    "wgs_rna_paired_smpls_df = pd.read_csv(wgs_rna_paired_smpls_path, sep=\"\\t\")\n",
    "egan_ids_with_rna = wgs_rna_paired_smpls_df[wgs_rna_paired_smpls_df.rna_id.isin(smpl_id_pass_qc_set)].egan_id.tolist()\n",
    "\n",
    "vrnts_gt_intersect_dir = root_dir_path.joinpath(f\"vrnts_gts_intersect/{chrom}/\")\n",
    "vrnts_gt_intersect_dir.mkdir(parents=True, exist_ok=True)\n",
    "vrnts_annotate_path = vrnts_gt_intersect_dir.joinpath(f\"{chrom}_{gene_id}_vrnts_gts_intersect.tsv\")\n",
    "\n",
    "# genotypes expected in the VCF; anything else aborts the run\n",
    "all_gts_set = {(1, 0), (0, 1), (1, 1), (0, 0), (None, None)}\n",
    "# genotypes written out as carriers when AF < 0.5 and when AF >= 0.5, respectively\n",
    "carriers_lt_50perc = {(1, 0), (0, 1), (1, 1)}\n",
    "carriers_gt_50perc = {(1, 0), (0, 1), (0, 0)}\n",
    "\n",
    "# context managers close the VCF and the output file even when one of the checks below raises\n",
    "with pysam.VariantFile(vcf_path, mode=\"r\") as vcf:\n",
    "    # subset vcf to samples with expression data (has to happen before any record is fetched)\n",
    "    vcf_samples = list(vcf.header.samples)\n",
    "    vcf_egan_ids_with_rna = set(egan_ids_with_rna).intersection(vcf_samples)\n",
    "    vcf.subset_samples(vcf_egan_ids_with_rna)\n",
    "    vcf_samples_with_rna = list(vcf.header.samples)\n",
    "    print(f\"Number of samples in VCF with RNA ID and passing QC: {len(vcf_samples_with_rna)}\")\n",
    "    # subset egan ID and RNA ID linker file to samples with genotype calls and passing QC\n",
    "    wgs_rna_paired_smpls_with_gts_df = wgs_rna_paired_smpls_df[wgs_rna_paired_smpls_df.egan_id.isin(vcf_samples_with_rna)]\n",
    "    rna_id_pass_qc_with_gts = wgs_rna_paired_smpls_with_gts_df.rna_id.unique().tolist()\n",
    "    print(f\"Number of RNA IDs passing QC with paired sample in VCF: {len(rna_id_pass_qc_with_gts)}\")\n",
    "\n",
    "    with open(vrnts_annotate_path, \"w\") as f_out:\n",
    "        header = \"\\t\".join([\"vrnt_id\", \"egan_id\", \"genotype\", \"AF\", \"vrnt_type\"])\n",
    "        f_out.write(f\"{header}\\n\")\n",
    "        # single pass over the coordinates instead of re-filtering the frame for every variant\n",
    "        vrnt_coords_df = gene_id_intersect_no_dupl_df[[\"vrnt_id\", \"chrom_vrnt\", \"start_vrnt\", \"end_vrnt\"]]\n",
    "        for vrnt_id, chrom_vrnt, vrnt_start, vrnt_end in vrnt_coords_df.itertuples(index=False, name=None):\n",
    "            found_vrnt_id = False\n",
    "            for rec in vcf.fetch(f\"chr{chrom_vrnt}\", vrnt_start, vrnt_end):\n",
    "                # record-local names so the notebook-level `chrom` is never overwritten\n",
    "                rec_chrom, rec_pos, ref, alt_alleles = rec.chrom, rec.pos, rec.ref, rec.alts\n",
    "                # check for multiallelic alleles\n",
    "                if len(alt_alleles) > 1:\n",
    "                    raise ValueError(f\"Multiallelic entry in vcf at: {rec_chrom}, {rec_pos}\")\n",
    "                alt = alt_alleles[0]\n",
    "                # other records overlapping the fetched interval are not the variant of interest\n",
    "                if vrnt_id != f\"{rec_chrom}:{rec_pos}:{ref}:{alt}\":\n",
    "                    continue\n",
    "                found_vrnt_id = True\n",
    "                # pair each genotype with its own sample name rather than indexing the header by position\n",
    "                smpl_gts = [(smpl_id, call[\"GT\"]) for smpl_id, call in rec.samples.items()]\n",
    "                gts_set = {gt for _, gt in smpl_gts}\n",
    "                # check genotypes\n",
    "                if not gts_set <= all_gts_set:\n",
    "                    raise ValueError(f\"Encountered different genotype at {rec_chrom}, {rec_pos}: {gts_set}.\")\n",
    "                af = rec.info[\"AF\"]\n",
    "                # pysam returns a tuple for per-allele (Number=A) INFO fields; records are biallelic here\n",
    "                if isinstance(af, tuple):\n",
    "                    af = af[0]\n",
    "                # assign carrier genotypes\n",
    "                carrier_gts = carriers_lt_50perc if af < 0.5 else carriers_gt_50perc\n",
    "                vrnt_type = \"snp\" if len(ref) == 1 and len(alt) == 1 else \"indel\"\n",
    "                # write one line per carrier; nothing is written when the variant has no carriers\n",
    "                for smpl_id, gt in smpl_gts:\n",
    "                    if gt in carrier_gts:\n",
    "                        line = \"\\t\".join([vrnt_id, smpl_id, str(gt), str(af), vrnt_type])\n",
    "                        f_out.write(f\"{line}\\n\")\n",
    "            if not found_vrnt_id:\n",
    "                raise ValueError(f\"Did not find {vrnt_id} in {vcf_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645c6aa4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
